#!/usr/bin/env python3
"""
Sweep over a range of lattice sizes and Wilson‑loop edge lengths.

This driver orchestrates the Volume 4 flip‑count simulator and Wilson‑loop
pipeline to systematically explore larger lattices and loop sizes.  For each
chosen lattice size it regenerates the underlying lattice and tick‑flip
statistics, updates the Volume 4 configuration to match, runs the full
simulation for U(1), SU(2) and SU(3) gauge groups and fits the resulting
Wilson‑loop averages to area‑ and perimeter‑law scalings.  A summary of the
fitted slopes is written to ``results_sweep/sweep_summary.csv`` and two
plots ``Slope_Area_vs_L.png`` and ``Slope_Perimeter_vs_L.png`` are saved
alongside the summary.

The script assumes it lives inside ``vol4-wilson-loop-pipeline_updated/scripts``
and that the sibling repositories ``vol4-flip-count-simulator_updated`` and
``ar-operator-core_updated`` reside one level up in the workspace.  No
external conda environments are required; the caller's Python interpreter is
used to run the subordinate scripts.
"""

import os
import yaml
import numpy as np
import pandas as pd
import subprocess
import matplotlib.pyplot as plt
import sys


def build_lattice(size: int, boundary: str = 'periodic') -> list:
    """Enumerate the forward links of a ``size`` x ``size`` square lattice.

    Each site ``(x, y)`` contributes up to two links: one along +x
    (``mu = 0``) and one along +y (``mu = 1``).  Links are listed with ``x``
    varying slowest, then ``y``, then ``mu``.

    Parameters
    ----------
    size : int
        Number of sites along each axis.  A non-positive value yields an
        empty list.
    boundary : str
        ``'periodic'`` keeps every link, because the far endpoint wraps
        around modulo ``size``.  Any other value is treated as open
        boundaries: a link is dropped when its forward neighbour would lie
        outside the lattice.

    Returns
    -------
    list
        Tuples ``((x, y), mu)``.  Only the starting site and the direction
        are stored; the neighbouring site is not part of the result.
    """
    wraps_around = boundary == 'periodic'
    unit_steps = ((1, 0), (0, 1))  # mu = 0 -> +x, mu = 1 -> +y

    def _neighbour_on_grid(px: int, py: int) -> bool:
        # Only consulted for open boundaries; periodic links never fall off.
        return 0 <= px < size and 0 <= py < size

    return [
        ((x, y), mu)
        for x in range(size)
        for y in range(size)
        for mu, (step_x, step_y) in enumerate(unit_steps)
        if wraps_around or _neighbour_on_grid(x + step_x, y + step_y)
    ]


def _load_yaml(path: str) -> dict:
    """Read a YAML mapping with the safe loader (an empty file yields ``{}``)."""
    with open(path, encoding='utf-8') as f:
        return yaml.safe_load(f) or {}


def _prepare_inputs(lattice_size: int, boundary: str,
                    flip_data_dir: str, wilson_data_dir: str) -> None:
    """Write ``lattice.npy`` to both data dirs and reset the base kernel(s).

    ``kernel.npy`` is rewritten as an all-ones vector with one entry per
    link, so its length always matches the freshly generated lattice.
    Stale ``kernel_SU2.npy`` / ``kernel_SU3.npy`` files are deleted so the
    downstream simulation cannot pick up kernels sized for a previous
    lattice.
    """
    lattice = np.array(build_lattice(lattice_size, boundary), dtype=object)
    for data_dir in (flip_data_dir, wilson_data_dir):
        os.makedirs(data_dir, exist_ok=True)
        np.save(os.path.join(data_dir, 'lattice.npy'), lattice)

    np.save(os.path.join(wilson_data_dir, 'kernel.npy'), np.ones((len(lattice),), dtype=float))

    for extra in ('kernel_SU2.npy', 'kernel_SU3.npy'):
        try:
            os.remove(os.path.join(wilson_data_dir, extra))
        except FileNotFoundError:
            # Nothing to clean up.  Any other OSError is allowed to propagate,
            # because a stale kernel left behind would fail the run later
            # with a less obvious shape error.
            pass


def _fit_log_slope(x: np.ndarray, log_w: np.ndarray) -> float:
    """Return the least-squares slope of ``log_w`` vs ``x`` over finite points.

    Points where ``log_w`` is not finite (e.g. ``log(0) = -inf``) are
    excluded.  ``nan`` is returned when fewer than two points remain,
    because a straight-line fit is undefined there.
    """
    mask = np.isfinite(log_w)
    if np.count_nonzero(mask) < 2:
        return float('nan')
    slope, _ = np.polyfit(x[mask], log_w[mask], 1)
    return float(slope)


def _extract_slopes(csv_path: str) -> tuple:
    """Fit log|<W>| from one results CSV against area and perimeter.

    The CSV must provide ``size``, ``real`` and ``imag`` columns.  Loops are
    treated as squares of edge ``size``: area ``size**2``, perimeter
    ``4*size``.  Returns ``(area_slope, perimeter_slope)``.
    """
    df = pd.read_csv(csv_path)
    sizes = df['size'].to_numpy()
    averages = df['real'].to_numpy() + 1j * df['imag'].to_numpy()
    # A vanishing average gives -inf; _fit_log_slope drops such points, so
    # the divide-by-zero warning is expected here and silenced.
    with np.errstate(divide='ignore'):
        log_w = np.log(np.abs(averages))
    return _fit_log_slope(sizes ** 2, log_w), _fit_log_slope(4 * sizes, log_w)


def _plot_slopes(df_summary: pd.DataFrame, results_dir: str) -> None:
    """Save one slope-vs-lattice-size PNG per scaling law into ``results_dir``."""
    prefix = 'area_slope_'
    # Derive the gauge groups from the summary columns so plotting does not
    # depend on any loop variable left over from the sweep.
    gauge_groups = [c[len(prefix):] for c in df_summary.columns if c.startswith(prefix)]
    for law, label in (('area', 'Area'), ('perim', 'Perimeter')):
        plt.figure()
        for G in gauge_groups:
            plt.plot(df_summary['L'], df_summary[f'{law}_slope_{G}'], marker='o', label=G)
        plt.xlabel('Lattice size')
        plt.ylabel(f'{label}‑law slope')
        plt.title(f'{label}‑law slope vs lattice size')
        if gauge_groups:
            plt.legend()
        # File names follow the module docstring: Slope_Area_vs_L.png and
        # Slope_Perimeter_vs_L.png.
        plot_path = os.path.join(results_dir, f'Slope_{label}_vs_L.png')
        plt.savefig(plot_path)
        plt.close()
        print(f'Saved {law}-law slopes plot to {plot_path}')


def main() -> None:
    """Sweep lattice sizes, rerun the pipeline and summarise the fitted slopes.

    For each lattice size this function:

    * overwrites ``lattice_size`` and ``loop_sizes`` in the Volume 4
      ``config.yaml``;
    * writes a new ``lattice.npy`` to both repositories' data directories,
      resets ``kernel.npy`` and deletes stale SU(2)/SU(3) kernels;
    * runs ``generate_flip_counts.py`` and then ``run_simulation.py`` with
      the current interpreter (either raises ``CalledProcessError`` on
      failure);
    * fits log|<W>| from each ``wilson_<G>.csv`` against loop area and
      perimeter.

    The per-size slopes are written to ``results_sweep/sweep_summary.csv``
    and plotted as ``Slope_Area_vs_L.png`` and ``Slope_Perimeter_vs_L.png``.
    """
    # This script is expected at vol4-wilson-loop-pipeline_updated/scripts,
    # so ``workspace`` is the pipeline repository and the flip simulator is
    # its sibling.
    workspace = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    flip_sim_dir = os.path.normpath(os.path.join(workspace, '..', 'vol4-flip-count-simulator_updated'))
    wilson_dir = workspace

    # Parameter sweeps
    lattice_sizes = [4, 6, 8]
    loop_sizes = [1, 2, 3, 4, 5]

    flip_cfg_path = os.path.join(flip_sim_dir, 'config.yaml')
    wilson_cfg_path = os.path.join(wilson_dir, 'config.yaml')
    gen_flip_script = os.path.join(flip_sim_dir, 'generate_flip_counts.py')
    run_sim_script = os.path.join(wilson_dir, 'run_simulation.py')

    results_dir = os.path.join(wilson_dir, 'results_sweep')
    os.makedirs(results_dir, exist_ok=True)

    # The flip-simulator config has no lattice-size-dependent fields.  Read
    # it once and never rewrite it: a YAML round-trip would only strip its
    # comments.
    flip_cfg = _load_yaml(flip_cfg_path)
    flip_data_dir = os.path.join(flip_sim_dir, flip_cfg.get('data_dir', 'data'))

    # generate_flip_counts imports flip_operators from ar-operator-core.
    # Prepend that repository instead of replacing PYTHONPATH, so entries the
    # caller already relies on stay importable.  The mapping is passed only
    # to that subprocess and leaves os.environ untouched.
    env = os.environ.copy()
    core_dir = os.path.normpath(os.path.join(flip_sim_dir, '..', 'ar-operator-core_updated'))
    existing_path = env.get('PYTHONPATH')
    env['PYTHONPATH'] = os.pathsep.join([core_dir, existing_path]) if existing_path else core_dir

    rows = []
    for lattice_size in lattice_sizes:
        # Reload every iteration so any edits made by the subordinate scripts
        # are preserved.  Keep the file's key order instead of sorting keys.
        wilson_cfg = _load_yaml(wilson_cfg_path)
        wilson_cfg['lattice_size'] = lattice_size
        wilson_cfg['loop_sizes'] = loop_sizes
        with open(wilson_cfg_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(wilson_cfg, f, sort_keys=False)

        wilson_data_dir = os.path.join(wilson_dir, wilson_cfg.get('data_dir', 'data'))
        _prepare_inputs(
            lattice_size,
            wilson_cfg.get('boundary_conditions', 'periodic'),
            flip_data_dir,
            wilson_data_dir,
        )

        # NOTE(review): no cwd is passed, so any relative paths inside the
        # subordinate scripts resolve against the caller's cwd.  Confirm they
        # resolve data/ relative to their own repositories.
        subprocess.run([sys.executable, gen_flip_script, '--config', flip_cfg_path], check=True, env=env)
        subprocess.run([sys.executable, run_sim_script, '--config', wilson_cfg_path], check=True)

        results_subdir = os.path.join(wilson_dir, wilson_cfg.get('results_dir', 'results'))
        slopes = {
            G: _extract_slopes(os.path.join(results_subdir, f'wilson_{G}.csv'))
            for G in wilson_cfg.get('gauge_groups', ['U1'])
        }
        # Column order: L, all area slopes, then all perimeter slopes.
        row = {'L': lattice_size}
        row.update({f'area_slope_{G}': area for G, (area, _) in slopes.items()})
        row.update({f'perim_slope_{G}': perim for G, (_, perim) in slopes.items()})
        rows.append(row)

    summary_path = os.path.join(results_dir, 'sweep_summary.csv')
    df_summary = pd.DataFrame(rows)
    df_summary.to_csv(summary_path, index=False)
    print(f'Saved sweep summary to {summary_path}')

    _plot_slopes(df_summary, results_dir)


if __name__ == '__main__':
    main()